'''
Answers research question 1: How well can states derived from the state features predict behavior after choosing activities?
Author: Meng Zhang
Date: January 2024
Disclaimer: adapted from the analysis code https://doi.org/10.4121/22153898.v1

Input: RL_trasition_weighted_reward.csv (file name exactly as read by this script)
Output: Figure 2, saved to Figures/Figure_2.pdf
'''

import numpy as np
import pandas as pd
import itertools
import scipy
import seaborn as sns
import matplotlib.pyplot as plt
import Utils as util
import Calculate_Q_values as cal_q
import Calculate_Actions as cal_act




### LOAD DATA
# Number of distinct actions (activity clusters); passed as num_act to the
# training helpers below.
NUM_ACTIONS = 14

# Helper from Utils; by its name it computes weighted sums of rewards for the
# transitions with weight 0.5 — TODO confirm. The last two return values are
# bound to `min` / `max` (shadowing the builtins) because later sections hand
# them on to Utils / Calculate_Q_values.
df_weighted, reward_mean, min, max = util.weighted_sum_of_reward__for_transitions(0.5)

# The two state columns are stored as list literals in the CSV and are parsed
# back into Python lists while reading.
# NOTE(review): `eval` executes arbitrary code found in these cells — only use
# this on trusted input files. ast.literal_eval would be safer if the cells are
# guaranteed to be plain list literals (confirm before switching).
data = pd.read_csv(
    "RL_trasition_weighted_reward.csv",
    converters={'Binary_State': eval, 'Binary_State_Next_Session': eval},
)

# Distinct participant ids (iteration order is the set's order).
all_people = list(set(data['rand_id'].tolist()))
NUM_PEOPLE = len(all_people)

print(f"Total number of samples: {len(data)}.")
print(f"Total number of people: {NUM_PEOPLE}.")

### FEATURE SELECTION
# Select NUM_FEAT_TO_SELECT of the 11 candidate features (indices 0-10):
# 0: usefulness belief about self-efficacy
# 1: usefulness belief about practical knowledge
# 2: NOTE(review): description missing from the original list — TODO fill in
# 3: usefulness belief about awareness of positive outcomes
# 4: usefulness belief about awareness of negative outcomes
# 5: usefulness belief about motivation to change
# 6: usefulness belief about knowledge of how to maintain/achieve mental well-being
# 7: usefulness belief about mindset that physical activity helps to quit smoking
# 8: usefulness belief about awareness of smoking patterns
# 9: usefulness belief about knowledge of how to maintain/achieve well-being
#    NOTE(review): 6 and 9 read almost identically — confirm the wording
# 10: NOTE(review): description missing from the original list — TODO fill in
NUM_FEAT_TO_SELECT = 3

# Lower/upper output bounds handed to util.get_map_effort_reward
# (presumably the range rewards are mapped onto — confirm in Utils).
OUTPUT_LOWER = -1
OUTPUT_HIGHER = 1

CANDIDATE_FEATURES = list(range(11))

# Deep copy of the loaded data, used as the training pool in the sections below.
data_train = data.copy(deep=True)
map_to_rewards = util.get_map_effort_reward(reward_mean, OUTPUT_LOWER, OUTPUT_HIGHER, min, max)

# Every column except the two identifier columns is passed to feature
# selection, as a list of row lists in CSV column order.
# NOTE(review): this relies on the CSV's column order matching what
# cal_q.feature_selection expects — confirm.
data_feat = data_train.drop(columns=['rand_id', 'session_num']).values.tolist()
feat_sel = cal_q.feature_selection(
    data_feat,
    reward_mean,
    min,
    max,
    CANDIDATE_FEATURES,
    NUM_FEAT_TO_SELECT,
)
print("Features selected:", feat_sel, "Weighted_mean", reward_mean)



### Compute mean L1-error per state for two approaches
### APPROACH 1: MEAN REWARD PER ACTION
# Leave-one-person-out evaluation of a state-agnostic baseline: the predicted
# reward of a held-out session is the training mean reward of the chosen action.
# reward_error_per_state_actions_only[s] collects the L1-errors of all held-out
# sessions whose selected-feature state has index s in all_states.
reward_error_per_state_actions_only = [[] for _ in range(2 ** NUM_FEAT_TO_SELECT)]

# All binary states over the selected features, in itertools.product order.
all_states = [list(i) for i in itertools.product([0, 1], repeat=NUM_FEAT_TO_SELECT)]

for i, person in enumerate(all_people):
    # use all the data except samples from the current person
    data_train_temp = data_train[data_train['rand_id'] != person]
    data_train_temp_samples = data_train_temp[
        ["Binary_State", "Binary_State_Next_Session", "cluster_new_index", "weighted_reward"]
    ].values.tolist()

    if i % 100 == 0:
        print(".... Person" + str(i))
    # train
    reward_avg_train = cal_act.compute_avg_reward(data_train_temp_samples, reward_mean, num_act=NUM_ACTIONS)

    # test on the left-out person, one session at a time in session order.
    # Iterating the person's rows directly visits the same rows as looking up
    # session_num == 1..n when sessions are numbered without gaps, and no longer
    # raises IndexError when a session number is missing or duplicated.
    data_person = data[data["rand_id"] == person].sort_values('session_num', kind='mergesort')
    for _, session in data_person.iterrows():
        action = int(session['cluster_new_index'])
        predicted_reward = reward_avg_train[action]
        actual_reward = util.map_efforts_to_rewards([session['weighted_reward']], map_to_rewards)[0]
        l1_error = abs(predicted_reward - actual_reward)
        state = list(np.take(session['Binary_State'], feat_sel))
        reward_error_per_state_actions_only[all_states.index(state)].append(l1_error)

print("\n***Approach 1: Mean reward per action***")
# A state with no held-out sessions yields nan (numpy warns on an empty mean/std).
reward_error_per_state_actions_only_mean = [np.mean(errors) for errors in reward_error_per_state_actions_only]
reward_error_per_state_actions_only_std = [np.std(errors) for errors in reward_error_per_state_actions_only]
print("Mean L1-error per state:", np.round(reward_error_per_state_actions_only_mean, 2))
print("                     SD:", np.round(reward_error_per_state_actions_only_std, 2))


#### APPROACH 2: MEAN REWARD PER ACTION AND STATE
# Leave-one-person-out evaluation using the learned per-state model: the
# predicted reward of a held-out session is sum over s' of
# trans_func[s, a, s'] * reward_func[s, a, s'] (the expected reward if
# trans_func holds transition probabilities — assumed from the helper's name).
# reward_error_per_state_action_state[s] collects the L1-errors per state index.
reward_error_per_state_action_state = [[] for _ in range(2 ** NUM_FEAT_TO_SELECT)]

# All binary states over the selected features, in itertools.product order.
abstract_states = [list(i) for i in itertools.product([0, 1], repeat=NUM_FEAT_TO_SELECT)]


for i, person in enumerate(all_people):
    # use all the data except samples from the current person
    data_train_temp = data_train[data_train['rand_id'] != person]
    data_train_temp_samples = data_train_temp[
        ["Binary_State", "Binary_State_Next_Session", "cluster_new_index", "weighted_reward"]
    ].values.tolist()

    if i % 100 == 0:
        print("...Person " + str(i))

    # train
    _, reward_func, trans_func, _ = cal_q.compute_q_vals_dynamics(data_train_temp_samples,
                                                                  reward_mean,
                                                                  min,
                                                                  max,
                                                                  feat_sel, num_act=NUM_ACTIONS)
    # test on the person that was left out from training, in session order.
    # Iterating the rows directly avoids the IndexError the session_num == 1..n
    # lookup raised for missing or duplicated session numbers.
    data_person = data[data['rand_id'] == person].sort_values('session_num', kind='mergesort')

    for _, session in data_person.iterrows():
        state = list(np.take(session['Binary_State'], feat_sel))
        state_idx = abstract_states.index(state)
        action = int(session['cluster_new_index'])
        # Keep fractional weighted rewards: an int() cast would truncate them,
        # unlike Approach 1 and the mean-reward section which map the raw value.
        effort = float(session["weighted_reward"])

        predicted_reward = sum(
            trans_func[state_idx, action, s_prime] * reward_func[state_idx, action, s_prime]
            for s_prime in range(2 ** NUM_FEAT_TO_SELECT)
        )
        actual_reward = util.map_efforts_to_rewards([effort], map_to_rewards)[0]
        l1_error = abs(predicted_reward - actual_reward)
        reward_error_per_state_action_state[state_idx].append(l1_error)

print("\n***Approach 2: Mean reward per action and state***")
# A state with no held-out sessions yields nan (numpy warns on an empty mean/std).
reward_error_per_state_action_state_mean = [np.mean(errors) for errors in reward_error_per_state_action_state]
reward_error_per_state_action_state_std = [np.std(errors) for errors in reward_error_per_state_action_state]
print("Mean L1-error per state:", np.round(reward_error_per_state_action_state_mean, 2))
print("                     SD:", np.round(reward_error_per_state_action_state_std, 2))

# Mean reward per state and overall
# reward_list[s] collects the mapped reward of every session whose
# selected-feature state has index s in all_states.
reward_list = [[] for _ in range(2 ** NUM_FEAT_TO_SELECT)]

all_states = [list(i) for i in itertools.product([0, 1], repeat=NUM_FEAT_TO_SELECT)]
for person in all_people:
    # data for this person, in session order (iterating rows directly avoids
    # the failing lookup for missing/duplicated session numbers)
    data_p = data[data['rand_id'] == person].sort_values('session_num', kind='mergesort')

    for _, session in data_p.iterrows():
        # Read the scalar from the row: float() on a one-element Series is
        # deprecated since pandas 2.1.
        weighted_reward = float(session['weighted_reward'])
        reward = util.map_efforts_to_rewards([weighted_reward], map_to_rewards)[0]
        s = list(np.take(session['Binary_State'], feat_sel))
        reward_list[all_states.index(s)].append(reward)

# Compute average reward per state and overall (i.e., across all states)
reward_avg_per_state = np.array([np.mean(rewards) for rewards in reward_list])
reward_std_per_state = np.array([np.std(rewards) for rewards in reward_list])
all_rewards = list(itertools.chain.from_iterable(reward_list))
reward_avg = np.mean(all_rewards)
reward_std = np.std(all_rewards)

print("Mean reward per state:\n", np.round(reward_avg_per_state, 2))
print("Mean reward across states:\n", round(reward_avg, 2), "(SD = " + str(np.round(reward_std, 2)) + ")")


### CREATE FIGURE 2
import os
import scipy.stats  # explicit: `import scipy` alone does not load scipy.stats on older SciPy

sns.set()
sns.set_style("white")
med_fontsize = 22
small_fontsize = 18
extrasmall_fontsize = 15
sns.set_context("paper", rc={"font.size": small_fontsize, "axes.titlesize": med_fontsize,
                             "axes.labelsize": med_fontsize,
                             'xtick.labelsize': small_fontsize, 'ytick.labelsize': small_fontsize,
                             'legend.fontsize': extrasmall_fontsize,
                             'legend.title_fontsize': extrasmall_fontsize})
patterns = ["-", ""]

num_states = 2 ** NUM_FEAT_TO_SELECT
x_vals = np.arange(num_states)
width = 0.4

fig, host = plt.subplots(figsize=(10, 5))
# Right-hand y-axis for the mean-reward lines; shares the x-axis with `host`.
par1 = host.twinx()

# Bars: mean L1-error per state for both approaches, side by side.
host.bar(x_vals - 0.5 * width, reward_error_per_state_actions_only_mean, width, color='gray',
         hatch=patterns[0])
host.bar(x_vals + 0.5 * width, reward_error_per_state_action_state_mean, width, color='deepskyblue',
         hatch=patterns[1])

# Lines: overall mean reward (constant) and mean reward per state.
par1.plot(x_vals, np.ones(num_states) * reward_avg, color="black", label="Overall")
par1.plot(x_vals, reward_avg_per_state, color="red", label="Per state")

host.set_ylim([0, 1])
host.set_ylabel("Mean L1-Error")
host.set_xlabel("State")
host.legend(["Mean reward per action", "Mean reward per action & state"], loc="upper left",
            title="Mean L1-Error", bbox_to_anchor=(-0.01, 1.01))

par1.set_ylabel("Mean Reward", fontsize=med_fontsize)
par1.legend(loc="upper center", title="Mean Reward", bbox_to_anchor=(0.64, 1.01))

# Tick labels are the binary state codes in itertools.product order, matching
# the state indexing of the per-state lists ("000", "001", ..., "111" for 3 features).
state_labels = ["".join(str(bit) for bit in bits)
                for bits in itertools.product([0, 1], repeat=NUM_FEAT_TO_SELECT)]
host.set_xticks(x_vals)
host.set_xticklabels(state_labels)

# Add Bayesian confidence intervals (i.e., credible intervals) for the mean.
# scipy.stats.bayes_mvs needs at least two observations, so a bar whose state
# has fewer L1-errors is drawn without an interval instead of aborting the figure.
alpha = 0.95
for state in range(num_states):
    for offset, errors in ((-0.5 * width, reward_error_per_state_actions_only[state]),
                           (0.5 * width, reward_error_per_state_action_state[state])):
        if len(errors) < 2:
            continue
        ci_low, ci_high = scipy.stats.bayes_mvs(errors, alpha=alpha)[0].minmax
        host.vlines(x=state + offset, ymin=ci_low, ymax=ci_high, color='black')

# savefig does not create missing directories.
os.makedirs("Figures", exist_ok=True)
fig.savefig("Figures/Figure_2.pdf", dpi=1500,
            bbox_inches='tight', pad_inches=0)


